﻿// Accord Machine Learning Library
// The Accord.NET Framework
// http://accord-framework.net
//
// Copyright © César Souza, 2009-2017
// cesarsouza at gmail.com
//
//    This library is free software; you can redistribute it and/or
//    modify it under the terms of the GNU Lesser General Public
//    License as published by the Free Software Foundation; either
//    version 2.1 of the License, or (at your option) any later version.
//
//    This library is distributed in the hope that it will be useful,
//    but WITHOUT ANY WARRANTY; without even the implied warranty of
//    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
//    Lesser General Public License for more details.
//
//    You should have received a copy of the GNU Lesser General Public
//    License along with this library; if not, write to the Free Software
//    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
//

namespace Accord.MachineLearning.Performance
{
    using System;
    using System.Linq;
    using Accord.Compat;

    /// <summary>
    ///   Class for representing results acquired through a 
    ///   <see cref="CrossValidation{TModel,TInput,TOutput}">k-fold cross-validation analysis</see>.
    /// </summary>
    /// 
    /// <remarks>
    ///   Besides exposing the model trained on each fold (<see cref="Models"/>) and the
    ///   aggregated training and validation statistics, this object can also act as an
    ///   ensemble: once <see cref="CombineMethod"/> is set, the <c>Transform</c> methods run
    ///   every fold's model on the given input and merge their outputs using that method.
    /// </remarks>
    /// 
    /// <typeparam name="TModel">The type of the machine learning model.</typeparam>
    /// <typeparam name="TInput">The type of the input data.</typeparam>
    /// <typeparam name="TOutput">The type of the output data or labels.</typeparam>
    /// 
    /// <example>
    ///   <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn" />
    ///   <code source="Unit Tests\Accord.Tests.MachineLearning\CrossValidationTest.cs" region="doc_learn_hmm" />
    ///   <code source="Unit Tests\Accord.Tests.MachineLearning\DecisionTrees\DecisionTreeTest.cs" region="doc_cross_validation" />
    ///   <code source="Unit Tests\Accord.Tests.MachineLearning\Bayes\NaiveBayesTest.cs" region="doc_cross_validation" />
    /// </example>
    /// 
    [Serializable]
    public class CrossValidationResult<TModel, TInput, TOutput> : TrainValSplit<CrossValidationStatistics>,
        ITransform<TInput, TOutput>
        where TModel : class, ITransform<TInput, TOutput>
    {
        private const string MissingCombineMethodMessage =
            "Please specify how the results of the different models should be combined by setting the CombineMethod property.";

        // Explicit backing field for CombineMethod. It is kept under its original name so that
        // field-name based serialization of this [Serializable] type remains compatible.
        private Func<TOutput[], TOutput> combineMethod;

        /// <summary>
        ///   Gets the total number of data samples in the entire data set,
        ///   computed as the sum of the validation-fold sizes.
        /// </summary>
        /// 
        public int NumberOfSamples
        {
            get { return Models.Select(x => x.Validation.NumberOfSamples).Sum(); }
        }

        /// <summary>
        ///   Gets the average number of data samples in 
        ///   each cross-validation fold of the data set.
        /// </summary>
        /// 
        public double AverageNumberOfSamples
        {
            get { return Models.Select(x => x.Validation.NumberOfSamples).Average(); }
        }

        /// <summary>
        ///   Gets the models created for each fold of the cross validation.
        /// </summary>
        /// 
        public SplitResult<TModel, TInput, TOutput>[] Models { get; private set; }

        /// <summary>
        ///   Gets or sets a tag for user-defined information.
        /// </summary>
        /// 
        public object Tag { get; set; }

        /// <summary>
        ///   Gets the number of inputs accepted by the model, as reported by the first fold.
        /// </summary>
        /// <value>The number of inputs.</value>
        /// <exception cref="System.ArgumentException">This property is read only; setting it always throws.</exception>
        public int NumberOfInputs
        {
            get { return Models[0].NumberOfInputs; }
            set { throw new ArgumentException("This property is read only."); }
        }

        /// <summary>
        ///   Gets the number of outputs generated by the model, as reported by the first fold.
        /// </summary>
        /// <value>The number of outputs.</value>
        /// <exception cref="System.ArgumentException">This property is read only; setting it always throws.</exception>
        public int NumberOfOutputs
        {
            get { return Models[0].NumberOfOutputs; }
            set { throw new ArgumentException("This property is read only."); }
        }

        /// <summary>
        ///   Initializes a new instance of the <see cref="CrossValidationResult{TModel, TInput, TOutput}"/> class.
        /// </summary>
        /// 
        /// <param name="models">The models created during the cross-validation runs.</param>
        /// 
        /// <exception cref="ArgumentNullException"><paramref name="models"/> is <c>null</c>.</exception>
        /// 
        public CrossValidationResult(SplitResult<TModel, TInput, TOutput>[] models)
        {
            if (models == null)
                throw new ArgumentNullException("models");

            double[] trainingValues = models.Select(x => x.Training.Value).ToArray();
            double[] trainingVariances = models.Select(x => x.Training.Variance).ToArray();
            int[] trainingCount = models.Select(x => x.Training.NumberOfSamples).ToArray();

            double[] validationValues = models.Select(x => x.Validation.Value).ToArray();
            double[] validationVariances = models.Select(x => x.Validation.Variance).ToArray();
            int[] validationCount = models.Select(x => x.Validation.NumberOfSamples).ToArray();

            this.Models = models;
            this.Training = new CrossValidationStatistics(trainingCount, trainingValues, trainingVariances);
            this.Validation = new CrossValidationStatistics(validationCount, validationValues, validationVariances);
        }

        /// <summary>
        ///   Applies the transformation to an input, producing an associated output. Every
        ///   fold's model is evaluated on <paramref name="input"/> and the individual outputs
        ///   are merged with <see cref="CombineMethod"/>.
        /// </summary>
        /// <param name="input">The input data to which the transformation should be applied.</param>
        /// <returns>The output generated by applying this transformation to the given input.</returns>
        /// <exception cref="InvalidOperationException"><see cref="CombineMethod"/> has not been set.</exception>
        public TOutput Transform(TInput input)
        {
            ThrowIfCombineMethodIsMissing();
            return TransformWithAllFolds(input, new TOutput[Models.Length]);
        }

        /// <summary>
        ///   Applies the transformation to a set of input vectors,
        ///   producing an associated set of output vectors.
        /// </summary>
        /// <param name="input">The input data to which
        ///   the transformation should be applied.</param>
        /// <returns>The output generated by applying this
        ///   transformation to the given input.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException"><see cref="CombineMethod"/> has not been set.</exception>
        public TOutput[] Transform(TInput[] input)
        {
            if (input == null)
                throw new ArgumentNullException("input");

            return Transform(input, new TOutput[input.Length]);
        }

        /// <summary>
        ///   Applies the transformation to a set of input vectors,
        ///   producing an associated set of output vectors.
        /// </summary>
        /// <param name="input">The input data to which
        ///   the transformation should be applied.</param>
        /// <param name="result">The location to where to store the
        ///   result of this transformation. Must have at least as many
        ///   elements as <paramref name="input"/>.</param>
        /// <returns>The output generated by applying this
        ///   transformation to the given input (the same array as <paramref name="result"/>).</returns>
        /// <exception cref="InvalidOperationException"><see cref="CombineMethod"/> has not been set.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> or <paramref name="result"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="result"/> is shorter than <paramref name="input"/>.</exception>
        public TOutput[] Transform(TInput[] input, TOutput[] result)
        {
            ThrowIfCombineMethodIsMissing();

            if (input == null)
                throw new ArgumentNullException("input");
            if (result == null)
                throw new ArgumentNullException("result");
            if (result.Length < input.Length)
                throw new ArgumentException("The result array must have at least as many elements as the input array.", "result");

            // A single scratch buffer holds the per-fold outputs and is reused for every
            // sample, so CombineMethod must not keep a reference to the array it receives.
            var foldOutputs = new TOutput[Models.Length];
            for (int sample = 0; sample < input.Length; sample++)
                result[sample] = TransformWithAllFolds(input[sample], foldOutputs);

            return result;
        }

        /// <summary>
        ///   Gets or sets the method used to combine the scores of different classifiers. 
        /// </summary>
        /// 
        public Func<TOutput[], TOutput> CombineMethod
        {
            get { return combineMethod; }
            set { combineMethod = value; }
        }

        /// <summary>
        ///   Evaluates every fold's model on one input, writing the individual outputs into
        ///   <paramref name="foldOutputs"/>, and merges them with <see cref="CombineMethod"/>.
        /// </summary>
        /// <param name="input">The single sample to transform.</param>
        /// <param name="foldOutputs">Scratch buffer with one slot per element of <see cref="Models"/>.</param>
        private TOutput TransformWithAllFolds(TInput input, TOutput[] foldOutputs)
        {
            for (int fold = 0; fold < Models.Length; fold++)
                foldOutputs[fold] = Models[fold].Model.Transform(input);

            return CombineMethod(foldOutputs);
        }

        /// <summary>
        ///   Throws <see cref="InvalidOperationException"/> when no <see cref="CombineMethod"/> is set.
        /// </summary>
        private void ThrowIfCombineMethodIsMissing()
        {
            if (CombineMethod == null)
                throw new InvalidOperationException(MissingCombineMethodMessage);
        }
    }
}
